{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "from transformers import ViTModel, ViTImageProcessor\n",
    "from utils import text_encoder_forward\n",
    "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n",
    "from utils import latents_to_images, downsampling, merge_and_save_images\n",
    "from omegaconf import OmegaConf\n",
    "from accelerate.utils import set_seed\n",
    "from tqdm import tqdm\n",
    "from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\n",
    "from PIL import Image\n",
    "from models.celeb_embeddings import embedding_forward\n",
    "import models.embedding_manager\n",
    "import importlib\n",
    "\n",
    "# seed = 42\n",
    "# set_seed(seed)  \n",
    "# torch.cuda.set_device(0)\n",
    "\n",
    "# set your sd2.1 path\n",
    "model_path = \"/home/user/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6\"\n",
    "pipe = StableDiffusionPipeline.from_pretrained(model_path)   \n",
    "pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
    "pipe = pipe.to(\"cuda\")\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "vae = pipe.vae\n",
    "unet = pipe.unet\n",
    "text_encoder = pipe.text_encoder\n",
    "tokenizer = pipe.tokenizer\n",
    "scheduler = pipe.scheduler\n",
    "\n",
    "input_dim = 64\n",
    "\n",
    "experiment_name = \"normal_GAN\"   # \"normal_GAN\", \"man_GAN\", \"woman_GAN\" , \n",
    "if experiment_name == \"normal_GAN\":\n",
    "    steps = 10000\n",
    "elif experiment_name == \"man_GAN\":\n",
    "    steps = 7000\n",
    "elif experiment_name == \"woman_GAN\":\n",
    "    steps = 6000\n",
    "else:\n",
    "    print(\"Hello, please notice this ^_^\")\n",
    "    assert 0\n",
    "\n",
    "\n",
    "original_forward = text_encoder.text_model.embeddings.forward\n",
    "text_encoder.text_model.embeddings.forward = embedding_forward.__get__(text_encoder.text_model.embeddings)\n",
    "embedding_manager_config = OmegaConf.load(\"datasets_face/identity_space.yaml\")\n",
    "Embedding_Manager = models.embedding_manager.EmbeddingManagerId_adain(  \n",
    "        tokenizer,\n",
    "        text_encoder,\n",
    "        device = device,\n",
    "        training = True,\n",
    "        experiment_name = experiment_name, \n",
    "        num_embeds_per_token = embedding_manager_config.model.personalization_config.params.num_embeds_per_token,            \n",
    "        token_dim = embedding_manager_config.model.personalization_config.params.token_dim,\n",
    "        mlp_depth = embedding_manager_config.model.personalization_config.params.mlp_depth,\n",
    "        loss_type = embedding_manager_config.model.personalization_config.params.loss_type,\n",
    "        vit_out_dim = input_dim,\n",
    ")\n",
    "\n",
    "\n",
    "embedding_path = os.path.join(\"training_weight\", experiment_name, \"embeddings_manager-{}.pt\".format(str(steps)))\n",
    "Embedding_Manager.load(embedding_path)\n",
    "text_encoder.text_model.embeddings.forward = original_forward\n",
    "\n",
    "print(\"finish init\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. create a new character and test with prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample a z\n",
    "for index in range(100):\n",
    "\n",
    "    random_embedding = torch.randn(1, 1, input_dim).to(device)\n",
    "\n",
    "    # map z to pseudo identity embeddings\n",
    "    _, emb_dict = Embedding_Manager(tokenized_text=None, embedded_text=None, name_batch=None, random_embeddings = random_embedding, timesteps = None,)\n",
    "\n",
    "    test_emb = emb_dict[\"adained_total_embedding\"].to(device)\n",
    "\n",
    "    v1_emb = test_emb[:, 0]\n",
    "    v2_emb = test_emb[:, 1]\n",
    "    embeddings = [v1_emb, v2_emb]\n",
    "\n",
    "    save_dir = os.path.join(\"test_results/\" + experiment_name, str(index))\n",
    "    os.makedirs(save_dir, exist_ok=True)    \n",
    "    test_emb_path = os.path.join(save_dir, \"id_embeddings.pt\")\n",
    "    torch.save(test_emb, test_emb_path)\n",
    "\n",
    "\n",
    "\n",
    "    '''insert into tokenizer & embedding layer'''\n",
    "    tokens = [\"v1*\", \"v2*\"]\n",
    "    embeddings = [v1_emb, v2_emb]\n",
    "    # add tokens and get ids\n",
    "    tokenizer.add_tokens(tokens)\n",
    "    token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "\n",
    "    # resize token embeddings and set new embeddings\n",
    "    text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of = 8)\n",
    "    for token_id, embedding in zip(token_ids, embeddings):\n",
    "        text_encoder.get_input_embeddings().weight.data[token_id] = embedding\n",
    "\n",
    "    prompts_list = [\"a photo of v1* v2*, facing to camera, best quality, ultra high res\",\n",
    "        \"v1* v2* wearing a Superman outfit, facing to camera, best quality, ultra high res\",\n",
    "        \"v1* v2* wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
    "        \"v1* v2* wearing a red sweater, facing to camera, best quality, ultra high res\",\n",
    "        \"v1* v2* wearing a blue hoodie, facing to camera, best quality, ultra high res\",\n",
    "    ]\n",
    "\n",
    "    for prompt in prompts_list:\n",
    "        image = pipe(prompt, guidance_scale = 8.5).images[0]\n",
    "        save_img_path = os.path.join(save_dir, prompt.replace(\"v1* v2*\", \"a person\") + '.png')\n",
    "        image.save(save_img_path)\n",
    "        print(save_img_path)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. directly use a chosen generated pseudo identity embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the path of your generated embeddings\n",
    "test_emb_path = \"test_results/normal_GAN/0000/id_embeddings.pt\"\n",
    "test_emb = torch.load(test_emb_path).cuda()\n",
    "v1_emb = test_emb[:, 0]\n",
    "v2_emb = test_emb[:, 1]\n",
    "\n",
    "\n",
    "index = \"chosen_index\"\n",
    "save_dir = os.path.join(\"test_results/\" + experiment_name, index)\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "\n",
    "'''insert into tokenizer & embedding layer'''\n",
    "tokens = [\"v1*\", \"v2*\"]\n",
    "embeddings = [v1_emb, v2_emb]\n",
    "# add tokens and get ids\n",
    "tokenizer.add_tokens(tokens)\n",
    "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "\n",
    "# resize token embeddings and set new embeddings\n",
    "text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of = 8)\n",
    "for token_id, embedding in zip(token_ids, embeddings):\n",
    "    text_encoder.get_input_embeddings().weight.data[token_id] = embedding\n",
    "\n",
    "prompts_list = [\"a photo of v1* v2*, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a Superman outfit, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a red sweater, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a purple wizard outfit, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a blue hoodie, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing headphones, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* with red hair, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing headphones with red hair, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a Christmas hat, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing sunglasses, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing sunglasses and necklace, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a blue cap, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a doctoral cap, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* with white hair, wearing glasses, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* in a helmet and vest riding a motorcycle, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* holding a bottle of red wine, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* driving a bus in the desert, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* playing basketball, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* playing the violin, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* piloting a spaceship, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* riding a horse, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* coding in front of a computer, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* laughing on the lawn, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* frowning at the camera, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* happily smiling, looking at the camera, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* crying disappointedly, with tears flowing, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing sunglasses, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* playing the guitar in the view of left side, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* holding a bottle of red wine, upper body, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing sunglasses and necklace, close-up, in the view of right side, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* riding a horse, in the view of the top, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a doctoral cap, upper body, with the left side of the face facing the camera, best quality, ultra high res\",\n",
    "    \"v1* v2* crying disappointedly, with tears flowing, with left side of the face facing the camera, best quality, ultra high res\",\n",
    "    \"v1* v2* sitting in front of the camera, with a beautiful purple sunset at the beach in the background, best quality, ultra high res\",\n",
    "    \"v1* v2* swimming in the pool, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* climbing a mountain, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* skiing on the snowy mountain, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* in the snow, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* in space wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
    "]\n",
    "\n",
    "for prompt in prompts_list:\n",
    "    image = pipe(prompt, guidance_scale = 8.5).images[0]\n",
    "    save_img_path = os.path.join(save_dir, prompt.replace(\"v1* v2*\", \"a person\") + '.png')\n",
    "    image.save(save_img_path)\n",
    "    print(save_img_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lbl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
